
Conversation

@JC-ut0 (Contributor) commented Aug 14, 2025

What this PR does / why we need it?

  1. MTP supports V1 scheduler
  2. Refactor attn metadata build

Does this PR introduce any user-facing change?

How was this patch tested?

  • v0.9.1-dev
  • A3 [TP16] [DP4 TP4]
  • A3 1P1D

@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request refactors attention metadata handling by introducing a centralized AscendCommonAttentionMetadata dataclass, a good architectural improvement that consolidates shared logic and reduces code duplication. However, the review identified several critical issues introduced by the refactoring, including incorrect tensor slicing and initialization that could lead to runtime errors or incorrect behavior. Specifically, there are bugs in build_dummy_metadata in attention_v1.py and in the build method of mla_v1.py's metadata builder, as well as some redundant assignments that should be cleaned up.
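
For orientation, a minimal sketch of what the shared dataclass could look like, reconstructed only from the fields the inline comments below reference (query_start_loc, query_start_loc_cpu, positions, attn_mask, attn_state); the actual field set and types in the PR may differ:

from dataclasses import dataclass
from typing import Any, Optional

import torch

@dataclass
class AscendCommonAttentionMetadata:
    # Cumulative query offsets per request, shape [num_reqs + 1], on device.
    query_start_loc: Optional[torch.Tensor] = None
    # Host-side mirror of query_start_loc for cheap CPU slicing.
    query_start_loc_cpu: Optional[torch.Tensor] = None
    # Positions of all scheduled tokens, on device.
    positions: Optional[torch.Tensor] = None
    # Attention mask and the current attention state (e.g. prefill vs. decode).
    attn_mask: Optional[torch.Tensor] = None
    attn_state: Optional[Any] = None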

block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
    block_table[:num_reqs])

query_start_loc = common_attn_metadata.query_start_loc

critical

The common_attn_metadata used here is initialized without a query_start_loc, causing common_attn_metadata.query_start_loc to be None. This None value is then passed to the AscendMetadata constructor, which will lead to a runtime error as query_start_loc is a required tensor. To fix this, a dummy query_start_loc tensor should be created for the decode-only state.

Suggested change
- query_start_loc = common_attn_metadata.query_start_loc
+ query_start_loc = torch.arange(0, num_reqs + 1, dtype=torch.int32, device=block_table.device)
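
As a sanity check on the suggestion, a tiny standalone sketch (num_reqs is a placeholder value): in the dummy decode-only state every request contributes exactly one query token, so the cumulative offsets are simply 0..num_reqs.

import torch

num_reqs = 4  # placeholder: number of dummy decode requests
# Request i's single query token occupies the half-open span [i, i + 1).
query_start_loc = torch.arange(0, num_reqs + 1, dtype=torch.int32)
print(query_start_loc.tolist())  # [0, 1, 2, 3, 4]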

actual_seq_lengths_q = query_start_loc[1:num_decodes+1].tolist()
max_seq_lens = seq_lens[:num_decodes].max().item()
seq_lens = seq_lens[:num_decodes]
input_positions = input_positions[:num_decodes]

critical

The input_positions tensor is incorrectly sliced using num_decodes (the number of decode requests). It should be sliced with num_decode_tokens (the number of tokens in decode requests) to select the correct positions for the decode phase. This is a critical bug that will lead to incorrect behavior.

Suggested change
- input_positions = input_positions[:num_decodes]
+ input_positions = input_positions[:num_decode_tokens]
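
To see why the two counts diverge under MTP/speculative decoding, a small illustrative sketch (the derivation from query_start_loc is an assumption, not code from this PR): each decode request can carry several tokens per step, so slicing by request count drops tokens.

import torch

# 3 decode requests, each proposing 2 speculative tokens this step.
query_start_loc = torch.tensor([0, 2, 4, 6], dtype=torch.int32)

num_decodes = query_start_loc.numel() - 1     # 3 requests
num_decode_tokens = int(query_start_loc[-1])  # 6 tokens

input_positions = torch.arange(num_decode_tokens)
print(input_positions[:num_decodes].numel())        # 3 -- the bug: half the tokens lost
print(input_positions[:num_decode_tokens].numel())  # 6 -- all decode-phase positions kept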

    self.device, non_blocking=True)
attn_mask = common_attn_metadata.attn_mask
attn_state = common_attn_metadata.attn_state
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]

high

This line is redundant as query_start_loc_cpu is already assigned with the same value on line 183. Please remove this duplicate assignment to improve code clarity and avoid potential confusion.

Comment on lines +396 to 407
input_positions = self.runner.positions_cpu[:num_tokens].to(
    device, non_blocking=True).long()

high

The input_positions tensor is assigned here from a CPU tensor and then immediately overwritten on line 398 with a device tensor from common_attn_metadata.positions. This first assignment is redundant and can be removed for clarity and to avoid unnecessary operations.
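
In GitHub-suggestion form the fix is a pure deletion (a sketch: the removed lines are the quoted diff above, and per this comment the later assignment from common_attn_metadata.positions remains the sole source of input_positions):

Suggested change
- input_positions = self.runner.positions_cpu[:num_tokens].to(
-     device, non_blocking=True).long()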

@JC-ut0 changed the title from "[v0.9.1] Refactor mla" to "[v0.9.1] Refactor attn metadata build" on Aug 14, 2025

This pull request has conflicts, please resolve those before we can evaluate the pull request.

):
    self.vllm_config = vllm_config
    self.model_config = vllm_config.model_config
    self.device = device
    self.runner = runner

remove runner
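
Presumably this asks the builder to take what it needs (vllm_config, device) directly rather than keeping a back-reference to the model runner; a hedged sketch of that constructor shape (class name and signature are illustrative, not the PR's actual code):

class AscendAttentionMetadataBuilder:

    def __init__(self, vllm_config, device):
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.device = device
        # No self.runner: per-step state (positions, query_start_loc, ...)
        # arrives through the common attention metadata argument instead of
        # being read off the runner object.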

@JC-ut0 force-pushed the v0.9.1-dev branch 2 times, most recently from 619fff4 to 96a8bfb on August 16, 2025 at 02:11
@github-actions bot added the documentation label Aug 16, 2025
Copy link

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@github-actions bot removed the documentation label Aug 16, 2025
@github-actions bot added the documentation label Aug 16, 2025
@JC-ut0 changed the title from "[v0.9.1] Refactor attn metadata build" to "[v0.9.1] MTP supports V1 scheduler" on Aug 16, 2025

This pull request has conflicts, please resolve those before we can evaluate the pull request.

Signed-off-by: xuyexiong <[email protected]>
@github-actions bot removed the merge-conflicts and documentation labels Aug 16, 2025
@ganyi1996ppo merged commit fad57a6 into vllm-project:v0.9.1-dev Aug 16, 2025
17 checks passed